{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pipelines"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pipelines are an integral part of River. We encourage their usage and apply them in many of our examples.\n",
    "\n",
    "The `compose.Pipeline` class contains all the logic for building and applying pipelines. A pipeline is essentially a list of estimators that are applied in sequence. The only requirement is that the first `n - 1` steps be transformers. The last step can be a regressor, a classifier, a clusterer, a transformer, etc.\n",
    "\n",
    "Here is an example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:49:52.864965Z",
     "iopub.status.busy": "2023-12-04T17:49:52.864389Z",
     "iopub.status.idle": "2023-12-04T17:49:53.144412Z",
     "shell.execute_reply": "2023-12-04T17:49:53.144042Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from river import compose\n",
    "from river import linear_model\n",
    "from river import preprocessing\n",
    "from river import feature_extraction\n",
    "\n",
    "model = compose.Pipeline(\n",
    "    preprocessing.StandardScaler(),\n",
    "    feature_extraction.PolynomialExtender(),\n",
    "    linear_model.LinearRegression()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also use the `|` operator, like so:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:49:53.146332Z",
     "iopub.status.busy": "2023-12-04T17:49:53.146215Z",
     "iopub.status.idle": "2023-12-04T17:49:53.154883Z",
     "shell.execute_reply": "2023-12-04T17:49:53.154651Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "model = (\n",
    "    preprocessing.StandardScaler() |\n",
    "    feature_extraction.PolynomialExtender() |\n",
    "    linear_model.LinearRegression()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Or, equivalently:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:49:53.156607Z",
     "iopub.status.busy": "2023-12-04T17:49:53.156514Z",
     "iopub.status.idle": "2023-12-04T17:49:53.165628Z",
     "shell.execute_reply": "2023-12-04T17:49:53.165348Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "model = preprocessing.StandardScaler()\n",
    "model |= feature_extraction.PolynomialExtender()\n",
    "model |= linear_model.LinearRegression()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Like any River estimator, a pipeline has a `_repr_html_` method, which is used to visualize it in Jupyter-like notebooks:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:49:53.167122Z",
     "iopub.status.busy": "2023-12-04T17:49:53.167045Z",
     "iopub.status.idle": "2023-12-04T17:49:53.180439Z",
     "shell.execute_reply": "2023-12-04T17:49:53.180202Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><div class=\"river-component river-pipeline\"><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">StandardScaler</pre></summary><code class=\"river-estimator-params\">StandardScaler (\n",
       "  with_std=True\n",
       ")\n",
       "</code></details><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">PolynomialExtender</pre></summary><code class=\"river-estimator-params\">PolynomialExtender (\n",
       "  degree=2\n",
       "  interaction_only=False\n",
       "  include_bias=False\n",
       "  bias_name=\"bias\"\n",
       ")\n",
       "</code></details><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">LinearRegression</pre></summary><code class=\"river-estimator-params\">LinearRegression (\n",
       "  optimizer=SGD (\n",
       "    lr=Constant (\n",
       "      learning_rate=0.01\n",
       "    )\n",
       "  )\n",
       "  loss=Squared ()\n",
       "  l2=0.\n",
       "  l1=0.\n",
       "  intercept_init=0.\n",
       "  intercept_lr=Constant (\n",
       "    learning_rate=0.01\n",
       "  )\n",
       "  clip_gradient=1e+12\n",
       "  initializer=Zeros ()\n",
       ")\n",
       "</code></details></div><style scoped>\n",
       ".river-estimator {\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "    max-width: max-content;\n",
       "}\n",
       "\n",
       ".river-pipeline {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    background: linear-gradient(#000, #000) no-repeat center / 1.5px 100%;\n",
       "}\n",
       "\n",
       ".river-union {\n",
       "    display: flex;\n",
       "    flex-direction: row;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-estimator {\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       "/* Vertical spacing between steps */\n",
       "\n",
       ".river-component + .river-component {\n",
       "    margin-top: 2em;\n",
       "}\n",
       "\n",
       ".river-union > .river-estimator {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .river-component {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .pipeline {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       "/* Spacing within a union of estimators */\n",
       "\n",
       ".river-union > .river-component + .river-component {\n",
       "    margin-left: 1em;\n",
       "}\n",
       "\n",
       "/* Typography */\n",
       "\n",
       ".river-estimator-params {\n",
       "    display: block;\n",
       "    white-space: pre-wrap;\n",
       "    font-size: 110%;\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator > .river-estimator-params,\n",
       ".river-wrapper > .river-details > river-estimator-params {\n",
       "    background-color: white !important;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-details {\n",
       "    margin-bottom: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator-name {\n",
       "    display: inline;\n",
       "    margin: 0;\n",
       "    font-size: 110%;\n",
       "}\n",
       "\n",
       "/* Toggle */\n",
       "\n",
       ".river-summary {\n",
       "    display: flex;\n",
       "    align-items:center;\n",
       "    cursor: pointer;\n",
       "}\n",
       "\n",
       ".river-summary > div {\n",
       "    width: 100%;\n",
       "}\n",
       "</style></div>"
      ],
      "text/plain": [
       "Pipeline (\n",
       "  StandardScaler (\n",
       "    with_std=True\n",
       "  ),\n",
       "  PolynomialExtender (\n",
       "    degree=2\n",
       "    interaction_only=False\n",
       "    include_bias=False\n",
       "    bias_name=\"bias\"\n",
       "  ),\n",
       "  LinearRegression (\n",
       "    optimizer=SGD (\n",
       "      lr=Constant (\n",
       "        learning_rate=0.01\n",
       "      )\n",
       "    )\n",
       "    loss=Squared ()\n",
       "    l2=0.\n",
       "    l1=0.\n",
       "    intercept_init=0.\n",
       "    intercept_lr=Constant (\n",
       "      learning_rate=0.01\n",
       "    )\n",
       "    clip_gradient=1e+12\n",
       "    initializer=Zeros ()\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`compose.Pipeline` implements a `learn_one` method, which calls the `learn_one` of each component in sequence. It also implements a `predict_one` (resp. `predict_proba_one`) method, which calls `transform_one` on the first `n - 1` steps and `predict_one` (resp. `predict_proba_one`) on the last step.\n",
    "\n",
    "Here is a small example to illustrate the previous point:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:49:53.181786Z",
     "iopub.status.busy": "2023-12-04T17:49:53.181717Z",
     "iopub.status.idle": "2023-12-04T17:49:53.319962Z",
     "shell.execute_reply": "2023-12-04T17:49:53.319706Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "({'ordinal_date': 736389,\n",
       "  'gallup': 43.843213,\n",
       "  'ipsos': 46.19925042857143,\n",
       "  'morning_consult': 48.318749,\n",
       "  'rasmussen': 44.104692,\n",
       "  'you_gov': 43.636914000000004},\n",
       " 43.75505)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from river import datasets\n",
    "\n",
    "dataset = datasets.TrumpApproval()\n",
    "x, y = next(iter(dataset))\n",
    "x, y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can predict the target value of a new sample by calling the `predict_one` method. By default, however, `predict_one` does not update any model parameters. The predictions are therefore 0, and the model parameters keep their default values (0 for the means of the `StandardScaler` component):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model.predict_one(x)=0.00, y=43.76\n",
      "model['StandardScaler'].means = defaultdict(<class 'float'>, {'ordinal_date': 0.0, 'gallup': 0.0, 'ipsos': 0.0, 'morning_consult': 0.0, 'rasmussen': 0.0, 'you_gov': 0.0})\n",
      "model.predict_one(x)=0.00, y=43.71\n",
      "model['StandardScaler'].means = defaultdict(<class 'float'>, {'ordinal_date': 0.0, 'gallup': 0.0, 'ipsos': 0.0, 'morning_consult': 0.0, 'rasmussen': 0.0, 'you_gov': 0.0})\n"
     ]
    }
   ],
   "source": [
    "for (x, y) in dataset.take(2):\n",
    "    print(f\"{model.predict_one(x)=:.2f}, {y=:.2f}\")\n",
    "    print(f\"{model['StandardScaler'].means = }\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Calling `learn_one` updates the stateful steps of the pipeline, so both the parameters and the predictions change:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model.predict_one(x)=0.88, y=43.76\n",
      "model['StandardScaler'].means = defaultdict(<class 'float'>, {'ordinal_date': 736389.0, 'gallup': 43.843213, 'ipsos': 46.19925042857143, 'morning_consult': 48.318749, 'rasmussen': 44.104692, 'you_gov': 43.636914000000004})\n",
      "model.predict_one(x)=9.44, y=43.71\n",
      "model['StandardScaler'].means = defaultdict(<class 'float'>, {'ordinal_date': 736389.5, 'gallup': 43.843213, 'ipsos': 46.19925042857143, 'morning_consult': 48.318749, 'rasmussen': 45.104692, 'you_gov': 42.636914000000004})\n"
     ]
    }
   ],
   "source": [
    "for (x, y) in dataset.take(2):\n",
    "    model.learn_one(x, y)\n",
    "\n",
    "    print(f\"{model.predict_one(x)=:.2f}, {y=:.2f}\")\n",
    "    print(f\"{model['StandardScaler'].means = }\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each component of the pipeline has been updated with the new data point.\n",
    "\n",
    "A pipeline is a very powerful tool that can be used to chain together multiple steps in a machine learning workflow.\n",
    "\n",
    "Note that it is also possible to call `transform_one` on a pipeline. This method runs `transform_one` on each transformer in turn and returns the output of the last transformer. That transformer is the penultimate step if the last step is a _predictor_ or _clusterer_, and the last step itself if the last step is a _transformer_:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:49:53.332637Z",
     "iopub.status.busy": "2023-12-04T17:49:53.332562Z",
     "iopub.status.idle": "2023-12-04T17:49:53.341944Z",
     "shell.execute_reply": "2023-12-04T17:49:53.341509Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'ordinal_date': 1.0,\n",
       " 'gallup': 0.0,\n",
       " 'ipsos': 0.0,\n",
       " 'morning_consult': 0.0,\n",
       " 'rasmussen': 1.0,\n",
       " 'you_gov': -1.0,\n",
       " 'ordinal_date*ordinal_date': 1.0,\n",
       " 'gallup*ordinal_date': 0.0,\n",
       " 'ipsos*ordinal_date': 0.0,\n",
       " 'morning_consult*ordinal_date': 0.0,\n",
       " 'ordinal_date*rasmussen': 1.0,\n",
       " 'ordinal_date*you_gov': -1.0,\n",
       " 'gallup*gallup': 0.0,\n",
       " 'gallup*ipsos': 0.0,\n",
       " 'gallup*morning_consult': 0.0,\n",
       " 'gallup*rasmussen': 0.0,\n",
       " 'gallup*you_gov': -0.0,\n",
       " 'ipsos*ipsos': 0.0,\n",
       " 'ipsos*morning_consult': 0.0,\n",
       " 'ipsos*rasmussen': 0.0,\n",
       " 'ipsos*you_gov': -0.0,\n",
       " 'morning_consult*morning_consult': 0.0,\n",
       " 'morning_consult*rasmussen': 0.0,\n",
       " 'morning_consult*you_gov': -0.0,\n",
       " 'rasmussen*rasmussen': 1.0,\n",
       " 'rasmussen*you_gov': -1.0,\n",
       " 'you_gov*you_gov': 1.0}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.transform_one(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In many cases, you might want to connect a step to multiple steps. For instance, you might want to extract different kinds of features from a single input. An elegant way to do this is to use a `compose.TransformerUnion`. Essentially, the latter is a list of transformers whose results are merged into a single `dict` when `transform_one` is called.\n",
    "\n",
    "As an example, let's say that we want to apply a `feature_extraction.RBFSampler` as well as the `feature_extraction.PolynomialExtender`. This may be done like so:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:49:53.357462Z",
     "iopub.status.busy": "2023-12-04T17:49:53.357331Z",
     "iopub.status.idle": "2023-12-04T17:49:53.369477Z",
     "shell.execute_reply": "2023-12-04T17:49:53.369159Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><div class=\"river-component river-pipeline\"><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">StandardScaler</pre></summary><code class=\"river-estimator-params\">StandardScaler (\n",
       "  with_std=True\n",
       ")\n",
       "</code></details><div class=\"river-component river-union\"><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">PolynomialExtender</pre></summary><code class=\"river-estimator-params\">PolynomialExtender (\n",
       "  degree=2\n",
       "  interaction_only=False\n",
       "  include_bias=False\n",
       "  bias_name=\"bias\"\n",
       ")\n",
       "</code></details><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">RBFSampler</pre></summary><code class=\"river-estimator-params\">RBFSampler (\n",
       "  gamma=1.\n",
       "  n_components=100\n",
       "  seed=None\n",
       ")\n",
       "</code></details></div><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">LinearRegression</pre></summary><code class=\"river-estimator-params\">LinearRegression (\n",
       "  optimizer=SGD (\n",
       "    lr=Constant (\n",
       "      learning_rate=0.01\n",
       "    )\n",
       "  )\n",
       "  loss=Squared ()\n",
       "  l2=0.\n",
       "  l1=0.\n",
       "  intercept_init=0.\n",
       "  intercept_lr=Constant (\n",
       "    learning_rate=0.01\n",
       "  )\n",
       "  clip_gradient=1e+12\n",
       "  initializer=Zeros ()\n",
       ")\n",
       "</code></details></div><style scoped>\n",
       ".river-estimator {\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "    max-width: max-content;\n",
       "}\n",
       "\n",
       ".river-pipeline {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    background: linear-gradient(#000, #000) no-repeat center / 1.5px 100%;\n",
       "}\n",
       "\n",
       ".river-union {\n",
       "    display: flex;\n",
       "    flex-direction: row;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-estimator {\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       "/* Vertical spacing between steps */\n",
       "\n",
       ".river-component + .river-component {\n",
       "    margin-top: 2em;\n",
       "}\n",
       "\n",
       ".river-union > .river-estimator {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .river-component {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .pipeline {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       "/* Spacing within a union of estimators */\n",
       "\n",
       ".river-union > .river-component + .river-component {\n",
       "    margin-left: 1em;\n",
       "}\n",
       "\n",
       "/* Typography */\n",
       "\n",
       ".river-estimator-params {\n",
       "    display: block;\n",
       "    white-space: pre-wrap;\n",
       "    font-size: 110%;\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator > .river-estimator-params,\n",
       ".river-wrapper > .river-details > river-estimator-params {\n",
       "    background-color: white !important;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-details {\n",
       "    margin-bottom: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator-name {\n",
       "    display: inline;\n",
       "    margin: 0;\n",
       "    font-size: 110%;\n",
       "}\n",
       "\n",
       "/* Toggle */\n",
       "\n",
       ".river-summary {\n",
       "    display: flex;\n",
       "    align-items:center;\n",
       "    cursor: pointer;\n",
       "}\n",
       "\n",
       ".river-summary > div {\n",
       "    width: 100%;\n",
       "}\n",
       "</style></div>"
      ],
      "text/plain": [
       "Pipeline (\n",
       "  StandardScaler (\n",
       "    with_std=True\n",
       "  ),\n",
       "  TransformerUnion (\n",
       "    PolynomialExtender (\n",
       "      degree=2\n",
       "      interaction_only=False\n",
       "      include_bias=False\n",
       "      bias_name=\"bias\"\n",
       "    ),\n",
       "    RBFSampler (\n",
       "      gamma=1.\n",
       "      n_components=100\n",
       "      seed=None\n",
       "    )\n",
       "  ),\n",
       "  LinearRegression (\n",
       "    optimizer=SGD (\n",
       "      lr=Constant (\n",
       "        learning_rate=0.01\n",
       "      )\n",
       "    )\n",
       "    loss=Squared ()\n",
       "    l2=0.\n",
       "    l1=0.\n",
       "    intercept_init=0.\n",
       "    intercept_lr=Constant (\n",
       "      learning_rate=0.01\n",
       "    )\n",
       "    clip_gradient=1e+12\n",
       "    initializer=Zeros ()\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = (\n",
    "    preprocessing.StandardScaler() |\n",
    "    (feature_extraction.PolynomialExtender() + feature_extraction.RBFSampler()) |\n",
    "    linear_model.LinearRegression()\n",
    ")\n",
    "\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the `+` symbol acts as a shorthand notation for creating a `compose.TransformerUnion`, which means that we could have declared the above pipeline like so:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:49:53.371204Z",
     "iopub.status.busy": "2023-12-04T17:49:53.371072Z",
     "iopub.status.idle": "2023-12-04T17:49:53.381553Z",
     "shell.execute_reply": "2023-12-04T17:49:53.381303Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "model = (\n",
    "    preprocessing.StandardScaler() |\n",
    "    compose.TransformerUnion(\n",
    "        feature_extraction.PolynomialExtender(),\n",
    "        feature_extraction.RBFSampler()\n",
    "    ) |\n",
    "    linear_model.LinearRegression()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pipelines remove a lot of cruft by taking care of tedious details for you. They also let you clearly define which steps your model is made of.\n",
    "\n",
    "Finally, having your model in a single object means that you can move it around more easily.\n",
    "\n",
    "Note that you can include user-defined functions in a pipeline by using a `compose.FuncTransformer`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Learning during predict\n",
    "\n",
    "In online machine learning, we can update the unsupervised parts of our model when a sample arrives. We don't _really_ have to wait for the ground truth to arrive in order to update unsupervised estimators that don't depend on it.\n",
    "\n",
    "In other words, in a pipeline, `learn_one` updates the supervised parts, whilst `predict_one` (or `predict_proba_one` for that matter) **can** update the unsupervised parts, which often yields better results.\n",
    "\n",
    "In River, we can achieve this behavior using a dedicated context manager: `compose.learn_during_predict`.\n",
    "\n",
    "Here is the same example as before, the only difference being that learning during predict is activated:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = (\n",
    "    preprocessing.StandardScaler() |\n",
    "    feature_extraction.PolynomialExtender() |\n",
    "    linear_model.LinearRegression()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model.predict_one(x)=0.00, y=43.76\n",
      "model['StandardScaler'].means = defaultdict(<class 'float'>, {'ordinal_date': 736389.0, 'gallup': 43.843213, 'ipsos': 46.19925042857143, 'morning_consult': 48.318749, 'rasmussen': 44.104692, 'you_gov': 43.636914000000004})\n",
      "model.predict_one(x)=0.00, y=43.71\n",
      "model['StandardScaler'].means = defaultdict(<class 'float'>, {'ordinal_date': 736389.5, 'gallup': 43.843213, 'ipsos': 46.19925042857143, 'morning_consult': 48.318749, 'rasmussen': 45.104692, 'you_gov': 42.636914000000004})\n"
     ]
    }
   ],
   "source": [
    "with compose.learn_during_predict():\n",
    "    for (x, y) in dataset.take(2):\n",
    "\n",
    "        print(f\"{model.predict_one(x)=:.2f}, {y=:.2f}\")\n",
    "        print(f\"{model['StandardScaler'].means = }\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Calling `predict_one` within this context updates each transformer of the pipeline. For instance, we can see here that the mean of each feature in the standard scaler step has been updated.\n",
    "\n",
    "On the other hand, the supervised part of our pipeline, the linear regression, has not been updated and has not learned anything yet. Hence the prediction on any sample is 0, because the model holds no weights yet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.0, {})"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict_one(x), model[\"LinearRegression\"].weights"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Performance Comparison\n",
    "\n",
    "One may wonder what the advantage of learning during predict is. Let's compare the performance of a pipeline with and without learning during predict, while varying the number of samples on which `learn_one` is called:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from contextlib import nullcontext\n",
    "from river import metrics\n",
    "\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def score_pipeline(learn_during_predict: bool, n_learning_samples: int | None = None) -> float:\n",
    "    \"\"\"Scores a pipeline on the TrumpApproval dataset.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    learn_during_predict : bool\n",
    "        Whether or not to learn the unsupervised components during the prediction step.\n",
    "        If False it will only learn when `learn_one` is explicitly called.\n",
    "    n_learning_samples : int | None\n",
    "        Number of leading samples on which `learn_one` is called. If None, all samples are used.\n",
    "\n",
    "    Returns\n",
    "    ------\n",
    "    MAE : float\n",
    "        Mean absolute error of the pipeline on the dataset\n",
    "    \"\"\"\n",
    "\n",
    "    dataset = datasets.TrumpApproval()\n",
    "\n",
    "    model = (\n",
    "        preprocessing.StandardScaler() |\n",
    "        linear_model.LinearRegression()\n",
    "    )\n",
    "\n",
    "    metric = metrics.MAE()\n",
    "\n",
    "    ctx = compose.learn_during_predict if learn_during_predict else nullcontext\n",
    "    n_learning_samples = n_learning_samples or dataset.n_samples\n",
    "\n",
    "    with ctx():\n",
    "        for _idx, (x, y) in enumerate(dataset):\n",
    "            y_pred = model.predict_one(x)\n",
    "\n",
    "            metric.update(y, y_pred)\n",
    "\n",
    "            if _idx < n_learning_samples:\n",
    "                model.learn_one(x, y)\n",
    "\n",
    "    return metric.get()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_samples = datasets.TrumpApproval().n_samples\n",
    "\n",
    "results = [\n",
    "    {\n",
    "        \"learn_during_predict\": learn_during_predict,\n",
    "        \"pct_learning_samples\": round(100*n_learning_samples/max_samples, 0),\n",
    "        \"mae\": score_pipeline(learn_during_predict=learn_during_predict, n_learning_samples=n_learning_samples)\n",
    "    }\n",
    "    for learn_during_predict in (True, False)\n",
    "    for n_learning_samples in range(max_samples, max_samples//10, -(max_samples//10))\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "</style>\n",
       "<table id=\"T_059da\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"index_name level0\" >learn_during_predict</th>\n",
       "      <th id=\"T_059da_level0_col0\" class=\"col_heading level0 col0\" >False</th>\n",
       "      <th id=\"T_059da_level0_col1\" class=\"col_heading level0 col1\" >True</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th class=\"index_name level0\" >pct_learning_samples</th>\n",
       "      <th class=\"blank col0\" >&nbsp;</th>\n",
       "      <th class=\"blank col1\" >&nbsp;</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_059da_level0_row0\" class=\"row_heading level0 row0\" >100.0%</th>\n",
       "      <td id=\"T_059da_row0_col0\" class=\"data row0 col0\" >1.314548</td>\n",
       "      <td id=\"T_059da_row0_col1\" class=\"data row0 col1\" >1.347434</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_059da_level0_row1\" class=\"row_heading level0 row1\" >90.0%</th>\n",
       "      <td id=\"T_059da_row1_col0\" class=\"data row1 col0\" >1.629333</td>\n",
       "      <td id=\"T_059da_row1_col1\" class=\"data row1 col1\" >1.355274</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_059da_level0_row2\" class=\"row_heading level0 row2\" >80.0%</th>\n",
       "      <td id=\"T_059da_row2_col0\" class=\"data row2 col0\" >2.712125</td>\n",
       "      <td id=\"T_059da_row2_col1\" class=\"data row2 col1\" >1.371599</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_059da_level0_row3\" class=\"row_heading level0 row3\" >70.0%</th>\n",
       "      <td id=\"T_059da_row3_col0\" class=\"data row3 col0\" >4.840620</td>\n",
       "      <td id=\"T_059da_row3_col1\" class=\"data row3 col1\" >1.440773</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_059da_level0_row4\" class=\"row_heading level0 row4\" >60.0%</th>\n",
       "      <td id=\"T_059da_row4_col0\" class=\"data row4 col0\" >8.918634</td>\n",
       "      <td id=\"T_059da_row4_col1\" class=\"data row4 col1\" >1.498240</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_059da_level0_row5\" class=\"row_heading level0 row5\" >50.0%</th>\n",
       "      <td id=\"T_059da_row5_col0\" class=\"data row5 col0\" >15.112753</td>\n",
       "      <td id=\"T_059da_row5_col1\" class=\"data row5 col1\" >1.878434</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_059da_level0_row6\" class=\"row_heading level0 row6\" >40.0%</th>\n",
       "      <td id=\"T_059da_row6_col0\" class=\"data row6 col0\" >26.387331</td>\n",
       "      <td id=\"T_059da_row6_col1\" class=\"data row6 col1\" >2.105553</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_059da_level0_row7\" class=\"row_heading level0 row7\" >30.0%</th>\n",
       "      <td id=\"T_059da_row7_col0\" class=\"data row7 col0\" >42.997083</td>\n",
       "      <td id=\"T_059da_row7_col1\" class=\"data row7 col1\" >3.654709</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_059da_level0_row8\" class=\"row_heading level0 row8\" >20.0%</th>\n",
       "      <td id=\"T_059da_row8_col0\" class=\"data row8 col0\" >90.703102</td>\n",
       "      <td id=\"T_059da_row8_col1\" class=\"data row8 col1\" >3.504950</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_059da_level0_row9\" class=\"row_heading level0 row9\" >10.0%</th>\n",
       "      <td id=\"T_059da_row9_col0\" class=\"data row9 col0\" >226.836953</td>\n",
       "      <td id=\"T_059da_row9_col1\" class=\"data row9 col1\" >4.803600</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x7f9c57bcbfd0>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(pd.DataFrame(results)\n",
    " .pivot(columns=\"learn_during_predict\", index=\"pct_learning_samples\", values=\"mae\")\n",
    " .sort_index(ascending=False)\n",
    " .style.format_index('{0}%')\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As the table above shows, the two scores are comparable only when 90% or more of the samples are used for learning. Below that, the score without learning during predict degrades quickly as the percentage of learning samples decreases. The gap becomes striking (one order of magnitude or more) when fewer than 50% of the samples are used for learning.\n",
    "\n",
    "Although this is a simple case, it exemplifies how powerful it can be to learn unsupervised components during predict."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
